{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# 自定义层\n",
    "\n",
    "深度学习的一个魅力在于神经网络中各式各样的层，例如全连接层和后面章节中将要介绍的卷积层、池化层与循环层。虽然`torch.nn`提供了大量常用的层，但有时候我们依然希望自定义层。本节将介绍如何使用`torch`来自定义一个网络层，从而可以被重复调用。\n",
    "\n",
    "\n",
    "## 不含模型参数的自定义层\n",
    "\n",
    "我们先介绍如何定义一个不含模型参数的自定义层。事实上，这和[“模型构造”](model-construction.ipynb)一节中介绍的使用`Module`类构造模型类似。下面的`CenteredLayer`类通过继承`Module`类自定义了一个将输入减掉均值后输出的层，并将层的计算定义在了`forward`函数里。这个层里不含模型参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "from torch import nn\n",
    "\n",
    "class CenteredLayer(nn.Module):\n",
    "    def __init__(self, **kwargs):\n",
    "        super(CenteredLayer, self).__init__(**kwargs)\n",
    "    \n",
    "    def forward(self, x):\n",
    "        return x - x.mean()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以实例化这个层，然后做前向计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([-2., -1.,  0.,  1.,  2.])"
      ]
     },
     "execution_count": 2,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "layer = CenteredLayer()\n",
    "layer(torch.Tensor([1, 2, 3, 4, 5]))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也可以用它来构造更复杂的模型。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "net = nn.Sequential(\n",
    "    nn.Linear(8, 128),\n",
    "    CenteredLayer()\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面打印自定义层各个输出的均值。因为均值是浮点数，所以它的值是一个很接近0的数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "7.450580596923828e-09"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "y = net(torch.rand(4, 8))\n",
    "y.mean().item()"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 含模型参数的自定义层\n",
    "\n",
    "我们还可以自定义含模型参数的自定义层。其中的模型参数可以通过训练学出。\n",
    "\n",
    "对于模型的可学习参数，需要使用`slef.param_name = nn.Parameter(初始化形式)`进行创建。当Paramenter赋值给Module的属性的时候，他会自动的被加到 Module的 参数列表中(即：会出现在  parameters() 迭代器中)，并且参数名为`param_name`。"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "现在我们尝试实现一个含权重参数和偏差参数的全连接层。它使用ReLU函数作为激活函数。其中`in_units`和`units`分别代表输入个数和输出个数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [],
   "source": [
    "class MyLinear(nn.Module):\n",
    "    # in_features为该层的输入特征数, out_features为该层的输出特征数\n",
    "    def __init__(self, in_features, out_features, **kwargs):\n",
    "        super(MyLinear, self).__init__(**kwargs)\n",
    "        self.weight = nn.Parameter(torch.rand(in_features, out_features))\n",
    "        self.bias = nn.Parameter(torch.rand(out_features))\n",
    "    \n",
    "    def forward(self, x):\n",
    "        linear = x.mm(self.weight) + self.bias\n",
    "        return torch.relu(linear)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "下面，我们实例化`MyDense`类并访问它的模型参数。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "weight torch.Size([5, 3]) torch.float32\n",
      "bias torch.Size([3]) torch.float32\n"
     ]
    }
   ],
   "source": [
    "linear = MyLinear(in_features=5, out_features=3)\n",
    "for name, param in linear.named_parameters():\n",
    "    print(name, param.shape, param.dtype)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们可以直接使用自定义层做前向计算。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[0.5579, 1.2363, 1.2546],\n",
       "        [0.7231, 1.3465, 1.1097]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 7,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "linear(torch.rand(2, 5))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "我们也可以使用自定义层构造模型。它和PyTorch的其他层在使用上很类似。"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "tensor([[39.8864],\n",
       "        [39.3420]], grad_fn=<ReluBackward0>)"
      ]
     },
     "execution_count": 8,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "net = nn.Sequential(\n",
    "    MyLinear(64, 8),\n",
    "    MyLinear(8, 1)\n",
    ")\n",
    "\n",
    "net(torch.rand(2, 64))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 小结\n",
    "\n",
    "* 可以通过`Module`类自定义神经网络中的层，从而可以被重复调用。\n",
    "\n",
    "\n",
    "## 练习\n",
    "\n",
    "* 自定义一个层，使用它做一次前向计算。\n",
    "\n",
    "\n",
    "\n",
    "\n",
    "## 扫码直达[讨论区](https://discuss.gluon.ai/t/topic/1256)\n",
    "\n",
    "![](../img/qr_custom-layer.svg)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python [conda env:pytorch]",
   "language": "python",
   "name": "conda-env-pytorch-py"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.9"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
